
import pandas as pd
from _funcs import parallel_read_csv, get_chunks, read_csv_chunk

# Root of the PatentsView data tree. Paths are formed by plain string
# concatenation onto this value (e.g. dir + 'rawdata/...'), so the trailing
# slash is required.
# NOTE(review): this name shadows the builtin `dir()`; renaming it means
# updating every reference to it at the same time.
dir = '/Volumes/Zihao_SSD2/PatentsView/'

## ===========================================================================================================
## Robustness check: Choose cutoff based on cosine similarity
## Zihao Li. 06/2024

## Inputs:  g_us_patent_citation.tsv (from PatentsView)
##          cleandata/sim_score_1981_2015_cutoff.csv

## Outputs: temp/omission_panel_cutoff.csv
## ===========================================================================================================

def main():
    """Build the omission panel for similarity-selected patent pairs.

    Steps:
      1. Load PatentsView's actual citation table and keep rows whose ids are
         numeric, then convert ``patent_id`` / ``citation_patent_id`` to numbers.
      2. Collapse it to one row per citing ``patent_id`` holding the list of
         cited ids (``actual_citation_list``) and its length (``num_citations``).
      3. Load the artificial pairs from
         ``cleandata/sim_score_1981_2015_cutoff.csv`` (minus the index/year
         columns) and left-merge the citation lists onto them by ``patent_id``.
      4. Drop pairs whose citing patent has no citation list, and set
         ``omission`` to 0 when ``cited_patent_id`` appears in that list,
         otherwise 1.
      5. Write the panel to ``temp/omission_panel_cutoff.csv``.

    All paths are relative to the module-level ``dir`` root.
    """
    ## Load actual citation data
    print('Loading actual citation data...')
    df_actual = parallel_read_csv(dir + 'rawdata/g_us_patent_citation.tsv', file_type='tsv')
    print('Shape of actual_citation dataset is {}.'.format(df_actual.shape)) # (128401915, 7)

    ## Special characters and non-numerics
    print('Removing non-numeric patents and converting patent_id to numeric...')
    # Drop citing ids that contain letters. na=True also drops missing ids;
    # without it a NaN in the mask makes the `~` below raise TypeError.
    has_letters = df_actual["patent_id"].str.contains("[a-zA-Z]", na=True)
    df_actual = df_actual[~has_letters]
    # Drop cited ids containing anything other than digits or '.'. Missing ids
    # count as non-numeric; this matches the previous `mask.astype(bool)`,
    # which also mapped NaN to True.
    non_numeric = df_actual["citation_patent_id"].str.contains("[^0-9.]", na=True)
    df_actual = df_actual[~non_numeric]
    df_actual["patent_id"] = pd.to_numeric(df_actual["patent_id"], errors="raise")
    df_actual["citation_patent_id"] = pd.to_numeric(df_actual["citation_patent_id"], errors="raise")
    print('Shape of actual_citation dataset (after removing non-numeric patents) is {}.'.format(df_actual.shape)) # (113603267, 7)

    ## Generate actual citation list
    print('Generating actual citation list...')
    df_actual_lst = (
        df_actual.groupby('patent_id')
        .agg({'citation_patent_id': lambda x: x.tolist()})
        .reset_index()
        .sort_values(by=['patent_id'])
    )
    # Defensive unwrap: if a cell's first element is itself a list (a nested
    # [[...]] result), keep that inner list; otherwise the cell is unchanged.
    # NOTE(review): presumably guards against a pandas-version quirk in agg.
    df_actual_lst['citation_patent_id'] = [
        row[0] if isinstance(row[0], list) else row
        for row in df_actual_lst['citation_patent_id']
    ]
    df_actual_lst = df_actual_lst.rename(columns={'citation_patent_id': 'actual_citation_list'})
    df_actual_lst['num_citations'] = df_actual_lst["actual_citation_list"].apply(len)
    print('Shape of actual_citation dataset (after grouping by patent_id) is {}.'.format(df_actual_lst.shape)) # (6719417, 3)
    del df_actual

    # Load artificial citation data
    print('Loading artificial citation data...')
    df_artificial = pd.read_csv(dir + 'cleandata/sim_score_1981_2015_cutoff.csv')
    df_artificial = df_artificial.drop(columns=['patent_idx', 'cited_patent_idx', 'patent_year', 'cited_patent_year'])
    print(df_artificial.shape) # (45610951, 3)

    # Merge
    print('Merging actual_citation_list with artificial citation data...')
    df = df_artificial.merge(df_actual_lst, how='left', on='patent_id')
    del df_actual_lst, df_artificial
    print('Shape of merged dataset is {}.'.format(df.shape)) # (45610951, 5)

    ## Generate omission index
    print('Generating omission index...')
    df = df.dropna(subset=['actual_citation_list'])
    # Iterate the two columns directly instead of DataFrame.apply(axis=1).
    # apply builds a Series per row, which is very slow at tens of millions of
    # rows. On an empty frame it also returns a DataFrame, which cannot be
    # assigned to a single column.
    df['omission'] = [
        0 if cited in actual_list else 1
        for cited, actual_list in zip(df['cited_patent_id'], df['actual_citation_list'])
    ]
    print('Shape of omission dataset is {}.'.format(df.shape)) # (42093135, 6)

    ## Export omission panel
    print('Exporting omission panel...')
    df.to_csv(dir + 'temp/omission_panel_cutoff.csv', index=False)

if __name__ == '__main__':
    # Run the pipeline only when executed as a script, not when imported.
    main()